{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pip install effdet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import cv2\n",
    "import torchvision.ops as ops\n",
    "from torchvision import transforms\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from effdet import create_model\n",
    "from effdet.config import get_efficientdet_config\n",
    "from boxmot import BotSort\n",
    "from boxmot.utils.ops import letterbox\n",
    "\n",
    "\n",
    "# Load EfficientDet model\n",
    "device = torch.device('mps')  # Use 'cuda' if you have a GPU\n",
    "\n",
    "model_name = 'tf_efficientdet_d0'  # You can choose a different variant like 'tf_efficientdet_d3'\n",
    "config = get_efficientdet_config(model_name)\n",
    "model = create_model(model_name, bench_task='predict', pretrained=True).to(device)\n",
    "model.eval()\n",
    "\n",
    "# Initialize the tracker\n",
    "tracker = BotSort(\n",
    "    reid_weights=Path('osnet_x0_25_msmt17.pt'),  # Path to ReID model\n",
    "    device=device,  # Use CPU for inference\n",
    "    half=False\n",
    ")\n",
    "\n",
    "input_size = config.image_size\n",
    "\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "# Open the video file\n",
    "vid = cv2.VideoCapture(0)  # or 'path/to/your.avi'\n",
    "\n",
    "while True:\n",
    "    # Capture frame-by-frame\n",
    "    ret, frame = vid.read()\n",
    "\n",
    "    # If ret is False, it means we have reached the end of the video\n",
    "    if not ret:\n",
    "        break\n",
    "\n",
    "    # Apply letterbox resizing\n",
    "    frame_letterbox, ratio, (dw, dh) = letterbox(frame, new_shape=input_size, auto=False, scaleFill=True)\n",
    "    \n",
    "    # Preprocess frame for EfficientDet (resize and normalize)\n",
    "    frame_tensor = preprocess(frame_letterbox).unsqueeze(0).to(device)\n",
    "\n",
    "    # Perform detection\n",
    "    with torch.no_grad():\n",
    "        detections = model(frame_tensor)[0]\n",
    "                \n",
    "    # Assuming detections is shaped [100, 6], with [x1, y1, x2, y2, confidence, class]\n",
    "    confidence_threshold = 0.5\n",
    "    \n",
    "    # Filter detections based on confidence threshold\n",
    "    mask = detections[:, 4] >= confidence_threshold\n",
    "    filtered_dets = detections[mask]\n",
    "\n",
    "    # Rescale coordinates from letterbox back to the original frame size\n",
    "    filtered_dets[:, 0] = (filtered_dets[:, 0] - dw) / ratio[0]\n",
    "    filtered_dets[:, 1] = (filtered_dets[:, 1] - dh) / ratio[1]\n",
    "    filtered_dets[:, 2] = (filtered_dets[:, 2] - dw) / ratio[0]\n",
    "    filtered_dets[:, 3] = (filtered_dets[:, 3] - dh) / ratio[1]\n",
    "\n",
    "    # Convert class to integer and stack results\n",
    "    dets = torch.cat((filtered_dets[:, :5], filtered_dets[:, 5].unsqueeze(1).int()), dim=1)\n",
    "\n",
    "    # Convert to numpy array (N X (x, y, x, y, conf, cls))\n",
    "    dets = dets.cpu().numpy()\n",
    "\n",
    "    # Update the tracker\n",
    "    res = tracker.update(dets, frame)  # --> M X (x, y, x, y, id, conf, cls, ind)\n",
    "\n",
    "    # Plot tracking results on the image\n",
    "    tracker.plot_results(frame, show_trajectories=True)\n",
    "\n",
    "    # Display the frame\n",
    "    cv2.imshow('BoXMOT + EfficientDet', frame)\n",
    "\n",
    "    # Simulate wait for key press to continue, press 'q' to exit\n",
    "    key = cv2.waitKey(1) & 0xFF\n",
    "    if key == ord('q'):\n",
    "        break\n",
    "\n",
    "# Release resources\n",
    "vid.release()\n",
    "cv2.destroyAllWindows()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "boxmot-YDNZdsaB-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
